{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Understanding Neurogym Task\n",
    "\n",
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/neurogym/neurogym/blob/master/examples/understanding_neurogym_task.ipynb)\n",
    "\n",
    "This tutorial explains how Neurogym tasks are structured. We will go through:\n",
    "\n",
    "1. Defining a basic gymnasium task\n",
    "2. Defining a basic trial-based neurogym task\n",
    "3. Adding observations and ground truth in neurogym tasks\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Installation on Google Colab\n",
    "\n",
    "Uncomment and execute the cell below if running on Google Colab.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# # Install gymnasium\n",
    "# ! pip install gymnasium\n",
    "\n",
    "# # Install neurogym\n",
    "# ! git clone https://github.com/neurogym/neurogym.git\n",
    "# %cd neurogym/\n",
    "# ! pip install -e ."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Gymnasium tasks\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Neurogym tasks follow the basic [Gymnasium](https://gymnasium.farama.org/) task format. Gymnasium is a maintained fork of [OpenAI’s Gym](https://github.com/openai/gym) library. Each task is defined as a Python class that inherits from the `gymnasium.Env` class.\n",
    "\n",
    "In this section we describe the basic structure of a Gymnasium task.\n",
    "\n",
    "In the `__init__` method, it is necessary to define two attributes, `self.observation_space` and `self.action_space`, which describe the kind of spaces used by observations (network inputs) and actions (network outputs).\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import gymnasium as gym\n",
    "\n",
    "\n",
    "class MyEnv(gym.Env):\n",
    "    def __init__(self):\n",
    "        super().__init__()  # Python boilerplate to initialize base class\n",
    "\n",
    "        # A two-dimensional box with minimum and maximum value set by low and high\n",
    "        self.observation_space = gym.spaces.Box(low=0.0, high=1.0, shape=(2,))\n",
    "\n",
    "        # A discrete space with 3 possible values (0, 1, 2)\n",
    "        self.action_space = gym.spaces.Discrete(3)\n",
    "\n",
    "\n",
    "# Instantiate an environment\n",
    "env = MyEnv()\n",
    "print(\"Sample random observation value:\", env.observation_space.sample())\n",
    "print(\"Sample random action value:\", env.action_space.sample())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Another key method that needs to be defined is the `step` method, which updates the environment and outputs observations and rewards after receiving the agent's action.\n",
    "\n",
    "The `step` method takes `action` as input, and outputs\n",
    "the agent's next observation `observation`,\n",
    "a scalar reward received by the agent `reward`,\n",
    "a boolean `terminated` describing whether the environment has reached a terminal state,\n",
    "a boolean `truncated` describing whether the episode was cut short (for example, by a time limit), and\n",
    "a dictionary holding any additional information `info`.\n",
    "\n",
    "If the environment is described by internal states, the `reset` method resets these internal states. This method returns an initial observation `observation` and an `info` dictionary.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MyEnv(gym.Env):\n",
    "    def __init__(self):\n",
    "        super().__init__()  # Python boilerplate to initialize base class\n",
    "        self.observation_space = gym.spaces.Box(low=-10.0, high=10.0, shape=(1,))\n",
    "        self.action_space = gym.spaces.Discrete(3)\n",
    "\n",
    "    def step(self, action):\n",
    "        ob = self.observation_space.sample()  # random sampling\n",
    "        reward = 1.0  # reward\n",
    "        terminated = False  # never ending\n",
    "        truncated = False\n",
    "        info = {}  # empty dictionary\n",
    "        return ob, reward, terminated, truncated, info\n",
    "\n",
    "    def reset(self, seed=None, options=None):\n",
    "        super().reset(seed=seed)  # initializes self.np_random when a seed is given\n",
    "        ob = self.observation_space.sample()\n",
    "        return ob, {}  # initial observation and an empty info dictionary"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Visualize results\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Below we define a simple task where actions move an agent along a one-dimensional line. The reward is determined by the agent's location on this line.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "\n",
    "def get_reward(x):\n",
    "    return np.sin(x) * np.exp(-np.abs(x) / 3)\n",
    "\n",
    "\n",
    "xs = np.linspace(-10, 10, 100)\n",
    "plt.plot(xs, get_reward(xs))\n",
    "plt.xlabel(\"State value (observation)\")\n",
    "plt.ylabel(\"Reward\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MyEnv(gym.Env):\n",
    "    def __init__(self):\n",
    "        super().__init__()  # Python boilerplate to initialize base class\n",
    "\n",
    "        # A one-dimensional box with minimum and maximum value set by low and high\n",
    "        self.observation_space = gym.spaces.Box(low=-10.0, high=10.0, shape=(1,))\n",
    "\n",
    "        # A discrete space with 3 possible values (0, 1, 2)\n",
    "        self.action_space = gym.spaces.Discrete(3)\n",
    "\n",
    "        self.state = 0.0\n",
    "\n",
    "    def step(self, action):\n",
    "        # Actions 0, 1, 2 correspond to state change of -0.1, 0, +0.1\n",
    "        self.state += (action - 1.0) * 0.1\n",
    "        self.state = np.clip(self.state, -10, 10)\n",
    "\n",
    "        ob = self.state  # observation\n",
    "        reward = get_reward(self.state)  # reward\n",
    "        terminated = False  # never ending\n",
    "        truncated = False\n",
    "        info = {}  # empty dictionary\n",
    "        return ob, reward, terminated, truncated, info\n",
    "\n",
    "    def reset(self, seed=None, options=None):\n",
    "        super().reset(seed=seed)  # initializes self.np_random when a seed is given\n",
    "        # Re-initialize state\n",
    "        self.state = self.observation_space.sample()\n",
    "        return self.state, {}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "An agent can interact with the environment iteratively.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "env = MyEnv()\n",
    "ob, _ = env.reset()\n",
    "ob_log = list()\n",
    "reward_log = list()\n",
    "for i in range(10000):\n",
    "    action = env.action_space.sample()  # A random agent\n",
    "    ob, reward, terminated, truncated, info = env.step(action)\n",
    "    ob_log.append(ob)\n",
    "    reward_log.append(reward)\n",
    "\n",
    "plt.plot(ob_log, reward_log)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Trial-based Neurogym Tasks\n",
    "\n",
    "Many neuroscience and cognitive science tasks have a trial structure. `neurogym.TrialEnv` provides a class for common trial-based tasks. Its main difference from `gymnasium.Env` is the `_new_trial()` method, which generates abstract information about a new trial and, optionally, the observation and ground-truth output. Additionally, users provide a `_step()` method instead of `step()`.\n",
    "\n",
    "The `_new_trial()` method takes any keyword arguments (`**kwargs`) and outputs a dictionary `trial` containing relevant information about this trial. This dictionary is accessible during `_step()` as `self.trial`.\n",
    "\n",
    "Here we define a simple task where the agent needs to make a binary decision on every trial based on its observation. Each trial is only one time step.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import neurogym as ngym\n",
    "from neurogym import TrialEnv\n",
    "\n",
    "\n",
    "class MyTrialEnv(TrialEnv):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.observation_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(1,))\n",
    "        self.action_space = gym.spaces.Discrete(2)\n",
    "\n",
    "        self.next_ob = np.random.uniform(-1, 1, size=(1,))\n",
    "\n",
    "    def _new_trial(self):\n",
    "        ob = self.next_ob  # observation previously computed\n",
    "        # Sample observation for the next trial\n",
    "        self.next_ob = np.random.uniform(-1, 1, size=(1,))\n",
    "\n",
    "        trial = dict()\n",
    "        # Ground-truth is 1 if ob > 0, else 0\n",
    "        trial[\"ground_truth\"] = (ob > 0) * 1.0\n",
    "\n",
    "        return trial\n",
    "\n",
    "    def _step(self, action):\n",
    "        ob = self.next_ob\n",
    "        # If action equals to ground_truth, reward=1, otherwise 0\n",
    "        reward = (action == self.trial[\"ground_truth\"]) * 1.0\n",
    "        terminated = False\n",
    "        truncated = False\n",
    "        info = {\"new_trial\": True}\n",
    "        return ob, reward, terminated, truncated, info"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "env = MyTrialEnv()\n",
    "ob, _ = env.reset()\n",
    "\n",
    "print(\"Trial\", 0)\n",
    "print(\"Received observation\", ob)\n",
    "\n",
    "for i in range(5):\n",
    "    action = env.action_space.sample()  # A random agent\n",
    "    print(\"Selected action\", action)\n",
    "    ob, reward, terminated, truncated, info = env.step(action)\n",
    "    print(\"Received reward\", reward)\n",
    "    print(\"Trial\", i + 1)\n",
    "    print(\"Received observation\", ob)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Including time, period, and observation in trial-based tasks\n",
    "\n",
    "Most neuroscience and cognitive science tasks follow additional temporal structures that are incorporated into `neurogym.TrialEnv`. These tasks typically\n",
    "\n",
    "1. Are described in real time instead of discrete time steps. For example, the task can last 3 seconds.\n",
    "2. Contain multiple time periods in each trial, such as a stimulus period and a response period.\n",
    "\n",
    "To include these features, neurogym tasks typically support setting the time length of each step in `dt` (in ms), and the time length of each time period in `timing`.\n",
    "\n",
    "For example, consider the following binary decision-making task with a 500ms stimulus period, followed by a 500ms decision period. The periods are added to each trial through `self.add_period()` in `self._new_trial()`. During `_step()`, you can check which period the task is currently in with `self.in_period(period_name)`.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MyDecisionEnv(TrialEnv):\n",
    "    def __init__(self, dt=100, timing=None):\n",
    "        super().__init__(dt=dt)  # dt is passed to base task\n",
    "\n",
    "        # Setting default task timing\n",
    "        self.timing = {\"stimulus\": 500, \"decision\": 500}\n",
    "        # Update timing if provided externally\n",
    "        if timing:\n",
    "            self.timing.update(timing)\n",
    "\n",
    "        self.observation_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(1,))\n",
    "        self.action_space = gym.spaces.Discrete(2)\n",
    "\n",
    "    def _new_trial(self):\n",
    "        # Setting time periods for this trial\n",
    "        periods = [\"stimulus\", \"decision\"]\n",
    "        # Will add stimulus and decision periods sequentially using self.timing info\n",
    "        self.add_period(periods)\n",
    "\n",
    "        # Sample observation for the next trial\n",
    "        stimulus = np.random.uniform(-1, 1, size=(1,))\n",
    "\n",
    "        trial = dict()\n",
    "        trial[\"stimulus\"] = stimulus\n",
    "        # Ground-truth is 1 if stimulus > 0, else 0\n",
    "        trial[\"ground_truth\"] = (stimulus > 0) * 1.0\n",
    "\n",
    "        return trial\n",
    "\n",
    "    def _step(self, action):\n",
    "        # Check if the current time step is in stimulus period\n",
    "        if self.in_period(\"stimulus\"):\n",
    "            ob = np.array([self.trial[\"stimulus\"]])\n",
    "            reward = 0.0  # no reward\n",
    "        else:\n",
    "            ob = np.array([0.0])  # no observation\n",
    "            # If action equals to ground_truth, reward=1, otherwise 0\n",
    "            reward = (action == self.trial[\"ground_truth\"]) * 1.0\n",
    "\n",
    "        terminated = False\n",
    "        truncated = False\n",
    "        # By default, the trial is not ended\n",
    "        info = {\"new_trial\": False}\n",
    "        return ob, reward, terminated, truncated, info"
   ]
  },
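  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With `dt=100` and the default `timing` used above, every period covers several time steps. The cell below is a minimal plain-Python sketch of that arithmetic, not neurogym code: it assumes a period of a given duration spans `duration / dt` steps, and the names `steps_per_period` and `steps_per_trial` are illustrative only.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dt = 100  # duration of one time step, in ms\n",
    "timing = {\"stimulus\": 500, \"decision\": 500}  # duration of each period, in ms\n",
    "\n",
    "# Assumption: each period spans duration / dt time steps\n",
    "steps_per_period = {name: duration // dt for name, duration in timing.items()}\n",
    "steps_per_trial = sum(steps_per_period.values())\n",
    "\n",
    "print(\"Steps per period:\", steps_per_period)\n",
    "print(\"Steps per trial:\", steps_per_trial)"
   ]
  },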
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We now run the environment with a random agent and plot the agent's observations, actions, and rewards.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Logging\n",
    "log = {\"ob\": [], \"gt\": [], \"action\": [], \"reward\": []}\n",
    "\n",
    "env = MyDecisionEnv(dt=100)\n",
    "ob, _ = env.reset()\n",
    "# .item() extracts the scalar from a size-1 array\n",
    "log[\"ob\"].append(ob.item())\n",
    "log[\"gt\"].append(float(ob.item() > 0))\n",
    "for i in range(30):\n",
    "    action = env.action_space.sample()  # A random agent\n",
    "    ob, reward, terminated, truncated, info = env.step(action)\n",
    "\n",
    "    log[\"action\"].append(float(action))\n",
    "    log[\"ob\"].append(ob.item())\n",
    "    log[\"gt\"].append(float(ob.item() > 0))\n",
    "    log[\"reward\"].append(float(np.squeeze(reward)))\n",
    "\n",
    "log[\"ob\"] = log[\"ob\"][:-1]  # exclude last observation\n",
    "log[\"gt\"] = log[\"gt\"][:-1]  # exclude last observation\n",
    "# Visualize\n",
    "f, axes = plt.subplots(len(log), 1, sharex=True)\n",
    "for ax, key in zip(axes, log):\n",
    "    ax.plot(log[key], \".-\")\n",
    "    ax.set_ylabel(key)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Setting observation and ground-truth at the beginning of each trial\n",
    "\n",
    "In many tasks, the observation and ground-truth are pre-determined for each trial, and can be set in `self._new_trial()`. The generated observation and ground-truth can then be used as inputs and targets for supervised learning.\n",
    "\n",
    "Observation and ground-truth can be set in `self._new_trial()` with the `self.add_ob()` and `self.set_groundtruth()` methods. Users can specify the period and location of the observation using their names. For example, `self.add_ob(1, period='stimulus', where='fixation')`.\n",
    "\n",
    "This allows users to access the observation and ground-truth of the entire trial with `self.ob` and `self.gt`, and their values at the current time step with `self.ob_now` and `self.gt_now`.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MyDecisionEnv(TrialEnv):\n",
    "    def __init__(self, dt=100, timing=None):\n",
    "        super().__init__(dt=dt)  # dt is passed to base task\n",
    "\n",
    "        # Setting default task timing\n",
    "        self.timing = {\"stimulus\": 500, \"decision\": 500}\n",
    "        # Update timing if provided externally\n",
    "        if timing:\n",
    "            self.timing.update(timing)\n",
    "\n",
    "        # Here we use ngym.spaces, which allows setting name of each dimension\n",
    "        name = {\"fixation\": 0, \"stimulus\": 1}\n",
    "        self.observation_space = ngym.spaces.Box(\n",
    "            low=-1.0, high=1.0, shape=(2,), name=name\n",
    "        )\n",
    "        name = {\"fixation\": 0, \"choice\": [1, 2]}\n",
    "        self.action_space = ngym.spaces.Discrete(3, name=name)\n",
    "\n",
    "    def _new_trial(self):\n",
    "        # Setting time periods for this trial\n",
    "        periods = [\"stimulus\", \"decision\"]\n",
    "        # Will add stimulus and decision periods sequentially using self.timing info\n",
    "        self.add_period(periods)\n",
    "\n",
    "        # Sample observation for the next trial\n",
    "        stimulus = np.random.uniform(-1, 1, size=(1,))\n",
    "\n",
    "        # Add value 1 to stimulus period at fixation location\n",
    "        self.add_ob(1, period=\"stimulus\", where=\"fixation\")\n",
    "        # Add value stimulus to stimulus period at stimulus location\n",
    "        self.add_ob(stimulus, period=\"stimulus\", where=\"stimulus\")\n",
    "\n",
    "        # Set ground_truth\n",
    "        groundtruth = int(stimulus[0] > 0)  # index the size-1 array before converting\n",
    "        self.set_groundtruth(groundtruth, period=\"decision\", where=\"choice\")\n",
    "\n",
    "        trial = dict()\n",
    "        trial[\"stimulus\"] = stimulus\n",
    "        trial[\"ground_truth\"] = groundtruth\n",
    "\n",
    "        return trial\n",
    "\n",
    "    def _step(self, action):\n",
    "        # self.ob_now and self.gt_now correspond to\n",
    "        # current step observation and groundtruth\n",
    "\n",
    "        # If action equals to ground_truth, reward=1, otherwise 0\n",
    "        reward = (action == self.gt_now) * 1.0\n",
    "\n",
    "        terminated = False\n",
    "        truncated = False\n",
    "        # By default, the trial is not ended\n",
    "        info = {\"new_trial\": False}\n",
    "        return self.ob_now, reward, terminated, truncated, info"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Sampling one trial. The trial observation and ground-truth can be used for supervised learning.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "env = MyDecisionEnv()\n",
    "env.reset()\n",
    "\n",
    "trial = env.new_trial()\n",
    "ob, gt = env.ob, env.gt\n",
    "\n",
    "print(\"Trial information\", trial)\n",
    "print(\"Observation shape is (N_time, N_unit) =\", ob.shape)\n",
    "print(\"Groundtruth shape is (N_time,) =\", gt.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Visualizing the environment with a helper function.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the environment for 2 trials using a random agent.\n",
    "fig = ngym.utils.plot_env(\n",
    "    env,\n",
    "    ob_traces=[\"fixation\", \"stimulus\"],  # labels in observation-dimension order\n",
    "    num_trials=2,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### An example perceptual decision-making task\n",
    "\n",
    "Using the above style, we can define a simple perceptual decision-making task (the PerceptualDecisionMaking task from neurogym).\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class PerceptualDecisionMaking(ngym.TrialEnv):\n",
    "    \"\"\"Two-alternative forced choice task in which the subject has to\n",
    "    integrate two stimuli to decide which one is higher on average.\n",
    "\n",
    "    Args:\n",
    "        stim_scale: float, controls the difficulty of the experiment (default: 1.0)\n",
    "        sigma: float, input noise level\n",
    "        dim_ring: int, dimension of ring input and output\n",
    "    \"\"\"\n",
    "\n",
    "    metadata = {\n",
    "        \"paper_link\": \"https://www.jneurosci.org/content/12/12/4745\",\n",
    "        \"paper_name\": \"\"\"The analysis of visual motion: a comparison of\n",
    "        neuronal and psychophysical performance\"\"\",\n",
    "        \"tags\": [\"perceptual\", \"two-alternative\", \"supervised\"],\n",
    "    }\n",
    "\n",
    "    def __init__(\n",
    "        self, dt=100, rewards=None, timing=None, stim_scale=1.0, sigma=1.0, dim_ring=2\n",
    "    ):\n",
    "        super().__init__(dt=dt)\n",
    "        # The strength of evidence, modulated by stim_scale\n",
    "        self.cohs = np.array([0, 6.4, 12.8, 25.6, 51.2]) * stim_scale\n",
    "        self.sigma = sigma / np.sqrt(self.dt)  # Input noise\n",
    "\n",
    "        # Rewards\n",
    "        self.rewards = {\"abort\": -0.1, \"correct\": +1.0, \"fail\": 0.0}\n",
    "        if rewards:\n",
    "            self.rewards.update(rewards)\n",
    "\n",
    "        self.timing = {\"fixation\": 100, \"stimulus\": 2000, \"delay\": 0, \"decision\": 100}\n",
    "        if timing:\n",
    "            self.timing.update(timing)\n",
    "\n",
    "        self.abort = False\n",
    "\n",
    "        self.theta = np.linspace(0, 2 * np.pi, dim_ring + 1)[:-1]\n",
    "        self.choices = np.arange(dim_ring)\n",
    "\n",
    "        name = {\"fixation\": 0, \"stimulus\": range(1, dim_ring + 1)}\n",
    "        self.observation_space = ngym.spaces.Box(\n",
    "            -np.inf, np.inf, shape=(1 + dim_ring,), dtype=np.float32, name=name\n",
    "        )\n",
    "        name = {\"fixation\": 0, \"choice\": range(1, dim_ring + 1)}\n",
    "        self.action_space = ngym.spaces.Discrete(1 + dim_ring, name=name)\n",
    "\n",
    "    def _new_trial(self, **kwargs):\n",
    "        # Trial info\n",
    "        trial = {\n",
    "            \"ground_truth\": self.rng.choice(self.choices),\n",
    "            \"coh\": self.rng.choice(self.cohs),\n",
    "        }\n",
    "        trial.update(kwargs)\n",
    "\n",
    "        coh = trial[\"coh\"]\n",
    "        ground_truth = trial[\"ground_truth\"]\n",
    "        stim_theta = self.theta[ground_truth]\n",
    "\n",
    "        # Periods\n",
    "        self.add_period([\"fixation\", \"stimulus\", \"delay\", \"decision\"])\n",
    "\n",
    "        # Observations\n",
    "        self.add_ob(1, period=[\"fixation\", \"stimulus\", \"delay\"], where=\"fixation\")\n",
    "        stim = np.cos(self.theta - stim_theta) * (coh / 200) + 0.5\n",
    "        self.add_ob(stim, \"stimulus\", where=\"stimulus\")\n",
    "        self.add_randn(0, self.sigma, \"stimulus\", where=\"stimulus\")\n",
    "\n",
    "        # Ground truth\n",
    "        self.set_groundtruth(ground_truth, period=\"decision\", where=\"choice\")\n",
    "\n",
    "        return trial\n",
    "\n",
    "    def _step(self, action):\n",
    "        new_trial = False\n",
    "        terminated = False\n",
    "        truncated = False\n",
    "        # rewards\n",
    "        reward = 0\n",
    "        gt = self.gt_now\n",
    "        # observations\n",
    "        if self.in_period(\"fixation\"):\n",
    "            if action != 0:  # action = 0 means fixating\n",
    "                new_trial = self.abort\n",
    "                reward += self.rewards[\"abort\"]\n",
    "        elif self.in_period(\"decision\"):\n",
    "            if action != 0:\n",
    "                new_trial = True\n",
    "                if action == gt:\n",
    "                    reward += self.rewards[\"correct\"]\n",
    "                    self.performance = 1\n",
    "                else:\n",
    "                    reward += self.rewards[\"fail\"]\n",
    "\n",
    "        return (\n",
    "            self.ob_now,\n",
    "            reward,\n",
    "            terminated,\n",
    "            truncated,\n",
    "            {\"new_trial\": new_trial, \"gt\": gt},\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "env = PerceptualDecisionMaking()\n",
    "fig = ngym.utils.plot_env(\n",
    "    env,\n",
    "    ob_traces=[\"fixation\", \"Stim1\", \"Stim2\"],  # labels in observation-dimension order\n",
    "    num_trials=2,\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ngym",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
